1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.lucene.classification;
18
19 import org.apache.lucene.analysis.Analyzer;
20 import org.apache.lucene.analysis.TokenStream;
21 import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
22 import org.apache.lucene.document.Document;
23 import org.apache.lucene.index.LeafReader;
24 import org.apache.lucene.index.IndexableField;
25 import org.apache.lucene.index.MultiFields;
26 import org.apache.lucene.index.Term;
27 import org.apache.lucene.index.Terms;
28 import org.apache.lucene.index.TermsEnum;
29 import org.apache.lucene.search.BooleanClause;
30 import org.apache.lucene.search.BooleanQuery;
31 import org.apache.lucene.search.IndexSearcher;
32 import org.apache.lucene.search.Query;
33 import org.apache.lucene.search.ScoreDoc;
34 import org.apache.lucene.search.WildcardQuery;
35 import org.apache.lucene.util.BytesRef;
36 import org.apache.lucene.util.BytesRefBuilder;
37 import org.apache.lucene.util.IntsRefBuilder;
38 import org.apache.lucene.util.fst.Builder;
39 import org.apache.lucene.util.fst.FST;
40 import org.apache.lucene.util.fst.PositiveIntOutputs;
41 import org.apache.lucene.util.fst.Util;
42
43 import java.io.IOException;
44 import java.util.List;
45 import java.util.Map;
46 import java.util.SortedMap;
47 import java.util.TreeMap;
48
49
50
51
52
53
54
55
56
57
58
59 public class BooleanPerceptronClassifier implements Classifier<Boolean> {
60
61 private Double threshold;
62 private final Integer batchSize;
63 private Terms textTerms;
64 private Analyzer analyzer;
65 private String textFieldName;
66 private FST<Long> fst;
67
68
69
70
71
72
73
74 public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
75 this.threshold = threshold;
76 this.batchSize = batchSize;
77 }
78
79
80
81
82
83
84
85 public BooleanPerceptronClassifier() {
86 batchSize = 1;
87 }
88
89
90
91
92 @Override
93 public ClassificationResult<Boolean> assignClass(String text)
94 throws IOException {
95 if (textTerms == null) {
96 throw new IOException("You must first call Classifier#train");
97 }
98 Long output = 0l;
99 try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
100 CharTermAttribute charTermAttribute = tokenStream
101 .addAttribute(CharTermAttribute.class);
102 tokenStream.reset();
103 while (tokenStream.incrementToken()) {
104 String s = charTermAttribute.toString();
105 Long d = Util.get(fst, new BytesRef(s));
106 if (d != null) {
107 output += d;
108 }
109 }
110 tokenStream.end();
111 }
112
113 return new ClassificationResult<>(output >= threshold, output.doubleValue());
114 }
115
116
117
118
119 @Override
120 public void train(LeafReader leafReader, String textFieldName,
121 String classFieldName, Analyzer analyzer) throws IOException {
122 train(leafReader, textFieldName, classFieldName, analyzer, null);
123 }
124
125
126
127
128 @Override
129 public void train(LeafReader leafReader, String textFieldName,
130 String classFieldName, Analyzer analyzer, Query query) throws IOException {
131 this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
132
133 if (textTerms == null) {
134 throw new IOException("term vectors need to be available for field " + textFieldName);
135 }
136
137 this.analyzer = analyzer;
138 this.textFieldName = textFieldName;
139
140 if (threshold == null || threshold == 0d) {
141
142 long sumDocFreq = leafReader.getSumDocFreq(textFieldName);
143 if (sumDocFreq != -1) {
144 this.threshold = (double) sumDocFreq / 2d;
145 } else {
146 throw new IOException(
147 "threshold cannot be assigned since term vectors for field "
148 + textFieldName + " do not exist");
149 }
150 }
151
152
153 SortedMap<String,Double> weights = new TreeMap<>();
154
155 TermsEnum termsEnum = textTerms.iterator();
156 BytesRef textTerm;
157 while ((textTerm = termsEnum.next()) != null) {
158 weights.put(textTerm.utf8ToString(), (double) termsEnum.totalTermFreq());
159 }
160 updateFST(weights);
161
162 IndexSearcher indexSearcher = new IndexSearcher(leafReader);
163
164 int batchCount = 0;
165
166 BooleanQuery.Builder q = new BooleanQuery.Builder();
167 q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, "*")), BooleanClause.Occur.MUST));
168 if (query != null) {
169 q.add(new BooleanClause(query, BooleanClause.Occur.MUST));
170 }
171
172 for (ScoreDoc scoreDoc : indexSearcher.search(q.build(),
173 Integer.MAX_VALUE).scoreDocs) {
174 Document doc = indexSearcher.doc(scoreDoc.doc);
175
176 IndexableField textField = doc.getField(textFieldName);
177
178
179 IndexableField classField = doc.getField(classFieldName);
180
181 if (textField != null && classField != null) {
182
183 ClassificationResult<Boolean> classificationResult = assignClass(textField.stringValue());
184 Boolean assignedClass = classificationResult.getAssignedClass();
185
186 Boolean correctClass = Boolean.valueOf(classField.stringValue());
187 long modifier = correctClass.compareTo(assignedClass);
188 if (modifier != 0) {
189 updateWeights(leafReader, scoreDoc.doc, assignedClass,
190 weights, modifier, batchCount % batchSize == 0);
191 }
192 batchCount++;
193 }
194 }
195 weights.clear();
196 }
197
198 @Override
199 public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
200 throw new IOException("training with multiple fields not supported by boolean perceptron classifier");
201 }
202
203 private void updateWeights(LeafReader leafReader,
204 int docId, Boolean assignedClass, SortedMap<String, Double> weights,
205 double modifier, boolean updateFST) throws IOException {
206 TermsEnum cte = textTerms.iterator();
207
208
209 Terms terms = leafReader.getTermVector(docId, textFieldName);
210
211 if (terms == null) {
212 throw new IOException("term vectors must be stored for field "
213 + textFieldName);
214 }
215
216 TermsEnum termsEnum = terms.iterator();
217
218 BytesRef term;
219
220 while ((term = termsEnum.next()) != null) {
221 cte.seekExact(term);
222 if (assignedClass != null) {
223 long termFreqLocal = termsEnum.totalTermFreq();
224
225 Long previousValue = Util.get(fst, term);
226 String termString = term.utf8ToString();
227 weights.put(termString, previousValue + modifier * termFreqLocal);
228 }
229 }
230 if (updateFST) {
231 updateFST(weights);
232 }
233 }
234
235 private void updateFST(SortedMap<String,Double> weights) throws IOException {
236 PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
237 Builder<Long> fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
238 BytesRefBuilder scratchBytes = new BytesRefBuilder();
239 IntsRefBuilder scratchInts = new IntsRefBuilder();
240 for (Map.Entry<String,Double> entry : weights.entrySet()) {
241 scratchBytes.copyChars(entry.getKey());
242 fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
243 .getValue().longValue());
244 }
245 fst = fstBuilder.finish();
246 }
247
248
249
250
251 @Override
252 public List<ClassificationResult<Boolean>> getClasses(String text)
253 throws IOException {
254 throw new RuntimeException("not implemented");
255 }
256
257
258
259
260 @Override
261 public List<ClassificationResult<Boolean>> getClasses(String text, int max)
262 throws IOException {
263 throw new RuntimeException("not implemented");
264 }
265
266 }